import torch
import torch.distributions
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from utils.datasets.tinyImages import _load_tiny_image, _preload_tiny_images
from utils.datasets.paths import get_CIFAR10_path, get_CIFAR100_path
from utils.datasets.cifar_augmentation import get_cifar10_augmentation
import os

def get_CIFAR10_subset(split, labeled_per_class, batch_size=128, augm_type='default', shuffle=True,
                       cutout_window=16, size=32, num_workers=8, id_config=None):
    """Return a DataLoader over one split of a class-balanced CIFAR-10 training subset.

    Args:
        split: Which split to load: 'train', 'val' or 'unlabeled' (validated by CIFARSubset).
        labeled_per_class: Labeled samples per class; only 400, 200 and 100 are supported.
        batch_size, shuffle, num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        augm_type, cutout_window, size: Forwarded to ``get_cifar10_augmentation``.
        id_config: Optional dict that, when given, is filled with a description of this loader.

    Raises:
        NotImplementedError: if ``labeled_per_class`` is not one of the supported values.
    """
    augm_config = {}
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size, config_dict=augm_config)

    # (labeled count, per-class sizes of the train / val / unlabeled splits); every row sums to 5000.
    known_split_sizes = (
        (400, [400, 500, 4100]),
        (200, [200, 700, 4100]),
        (100, [100, 500, 4400]),
    )
    per_split_sizes = next(
        (sizes for count, sizes in known_split_sizes if labeled_per_class == count),
        None,
    )
    if per_split_sizes is None:
        raise NotImplementedError()

    # generate_idcs=False: the cached index files must already exist, otherwise CIFARSubset raises.
    subset = CIFARSubset(split, 'CIFAR10', per_split_sizes, transform, generate_idcs=False)
    data_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=shuffle,
                                              num_workers=num_workers)

    if id_config is not None:
        id_config.update({
            'Dataset': 'CIFAR10-Subset',
            'Batch size': batch_size,
            'Labeled samples': labeled_per_class,
            'Augmentation': augm_config,
        })

    return data_loader

def get_CIFAR100_subset(split, samples_per_class, batch_size=128,  augm_type='default', shuffle=True,
                       cutout_window=16, size=32, num_workers=8, id_config=None):
    """Return a DataLoader over one split of a class-balanced CIFAR-100 training subset.

    Args:
        split: Which split to load: 'train', 'val' or 'unlabeled' (validated by CIFARSubset).
        samples_per_class: Per-class sizes of the (train, val, unlabeled) splits. CIFARSubset
            unpacks this into three splits and keys its cache files on the first entry.
        batch_size, shuffle, num_workers: Forwarded to ``torch.utils.data.DataLoader``.
        augm_type, cutout_window, size: Forwarded to ``get_cifar10_augmentation``.
        id_config: Optional dict that, when given, is filled with a description of this loader.
    """
    augm_config = {}
    # NOTE(review): CIFAR-100 reuses the CIFAR-10 augmentation pipeline; confirm this is intended.
    transform = get_cifar10_augmentation(augm_type, cutout_window, out_size=size, config_dict=augm_config)

    # generate_idcs=True: missing index files are created on first use.
    subset = CIFARSubset(split, 'CIFAR100', samples_per_class, transform, generate_idcs=True)
    data_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=shuffle,
                                              num_workers=num_workers)

    if id_config is not None:
        id_config.update({
            'Dataset': 'CIFAR100-Subset',
            'Batch size': batch_size,
            'Samples per class': samples_per_class,
            'Augmentation': augm_config,
        })

    return data_loader

def _get_idcs_filename(dataset_name, points_per_class):
    """Return the (train, val, unlabeled) index-file names for a dataset / per-class count.

    The names carry no directory component, so they are relative to the current working directory.
    """
    return tuple(
        f'ssl_{split_name}_{dataset_name}_{points_per_class}.pt'
        for split_name in ('train', 'val', 'unlabeled')
    )

def _generate_subset(dataset, num_classes, points_per_split_per_class, generator=None):
    """Randomly partition every class of ``dataset`` into disjoint, class-balanced splits.

    Args:
        dataset: Object exposing a ``targets`` sequence of integer class labels.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are processed.
        points_per_split_per_class: Samples per class for each split. Their sum must equal the
            number of samples of every class, so each sample lands in exactly one split.
        generator: Optional ``torch.Generator`` used for the random permutations, enabling
            reproducible splits. When ``None``, the global torch RNG is used.

    Returns:
        A list with one 1-D ``torch.long`` tensor of dataset indices per split. Split ``s`` holds
        ``num_classes * points`` entries, and the indices of class ``i`` occupy the contiguous
        slice ``[i * points, (i + 1) * points)``.

    Raises:
        ValueError: if some class does not contain exactly ``sum(points_per_split_per_class)``
            samples.
    """
    splits = [torch.zeros(num_classes * points, dtype=torch.long) for points in points_per_split_per_class]

    labels_tensor = torch.as_tensor(dataset.targets, dtype=torch.long)
    total_points_per_class = sum(points_per_split_per_class)

    for i in range(num_classes):
        # flatten() (not squeeze()) keeps a 1-D tensor even when the class has a single sample.
        class_i_idcs = torch.nonzero(labels_tensor == i, as_tuple=False).flatten()
        if len(class_i_idcs) != total_points_per_class:
            # Explicit check instead of assert: asserts are stripped when Python runs with -O.
            raise ValueError(
                f'Class {i} has {len(class_i_idcs)} samples, but the splits '
                f'{list(points_per_split_per_class)} require exactly {total_points_per_class}'
            )

        rand_idcs = torch.randperm(len(class_i_idcs), generator=generator)

        # Consume consecutive, non-overlapping chunks of the permutation, one chunk per split.
        start = 0
        for split_idcs, points_per_class in zip(splits, points_per_split_per_class):
            chosen_idcs = class_i_idcs[rand_idcs[start:start + points_per_class]]
            split_idcs[i * points_per_class:(i + 1) * points_per_class] = chosen_idcs
            start += points_per_class

    return splits

class CIFARSubset(Dataset):
    """One split of a class-balanced subset of the CIFAR-10 / CIFAR-100 training set.

    The subset is defined by index tensors cached in three files (train / val / unlabeled) whose
    names come from ``_get_idcs_filename``. Missing files are generated with ``_generate_subset``
    when ``generate_idcs`` is True.
    """

    def __init__(self, split, dataset, samples_per_class_per_split, transform, generate_idcs=False):
        """Load the wrapped CIFAR training set and the cached indices for ``split``.

        Args:
            split: 'train', 'val' or 'unlabeled'.
            dataset: 'CIFAR10' or 'CIFAR100' (case-insensitive).
            samples_per_class_per_split: Per-class sizes of the (train, val, unlabeled) splits;
                must have three entries when the index files are generated.
            transform: Transform passed to the torchvision CIFAR dataset.
            generate_idcs: Whether missing index files may be generated and saved.

        Raises:
            NotImplementedError: for an unsupported ``dataset``.
            ValueError: for an unknown ``split``, or when the index files are missing and
                ``generate_idcs`` is False.
        """
        dataset_key = dataset.lower()
        if dataset_key == 'cifar10':
            self.cifar = datasets.CIFAR10(get_CIFAR10_path(), train=True, transform=transform)
            self.num_classes = 10
        elif dataset_key == 'cifar100':
            self.cifar = datasets.CIFAR100(get_CIFAR100_path(), train=True, transform=transform)
            self.num_classes = 100
        else:
            raise NotImplementedError(f'Unsupported dataset: {dataset!r}')

        # NOTE(review): file names only encode the first (train) split size, so different
        # val/unlabeled sizes with the same train size silently reuse the same cached files.
        train_filename, val_filename, unlabeled_filename = _get_idcs_filename(
            dataset_key, samples_per_class_per_split[0])
        split_to_filename = {
            'train': train_filename,
            'val': val_filename,
            'unlabeled': unlabeled_filename,
        }
        # Validate the split before anything is generated or written to disk.
        if split not in split_to_filename:
            raise ValueError(f'Unknown split {split!r}; expected one of {list(split_to_filename)}')

        exists = all(os.path.isfile(name) for name in split_to_filename.values())
        if not exists:
            if not generate_idcs:
                raise ValueError('Idx file does not exist and generate not set to True')
            print(f'Generating indices for splits {samples_per_class_per_split}')
            train_idcs, val_idcs, unlabeled_idcs = _generate_subset(
                self.cifar, self.num_classes, samples_per_class_per_split)
            torch.save(train_idcs, train_filename)
            torch.save(val_idcs, val_filename)
            torch.save(unlabeled_idcs, unlabeled_filename)

        self.idcs = torch.load(split_to_filename[split])
        self.length = len(self.idcs)

        print(f'{dataset} subset - {split} split - {samples_per_class_per_split} labeled per class - Split size {self.length} ')

    @property
    def targets(self):
        """Labels of the samples in this split, in index order."""
        # int() turns 0-d tensor elements into plain Python ints for list indexing.
        return [self.cifar.targets[int(i)] for i in self.idcs]

    def __getitem__(self, index):
        """Return the wrapped CIFAR item at the ``index``-th subset position."""
        return self.cifar[int(self.idcs[index])]

    def __len__(self):
        """Number of samples in this split."""
        return self.length
